// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/tools/quic/quic_dispatcher.h"

#include <memory>
#include <ostream>
#include <string>

#include "base/macros.h"
#include "base/strings/string_piece.h"
#include "net/quic/crypto/crypto_handshake.h"
#include "net/quic/crypto/quic_crypto_server_config.h"
#include "net/quic/crypto/quic_random.h"
#include "net/quic/quic_chromium_connection_helper.h"
#include "net/quic/quic_crypto_stream.h"
#include "net/quic/quic_flags.h"
#include "net/quic/quic_utils.h"
#include "net/quic/test_tools/crypto_test_utils.h"
#include "net/quic/test_tools/quic_test_utils.h"
#include "net/tools/epoll_server/epoll_server.h"
#include "net/tools/quic/quic_epoll_alarm_factory.h"
#include "net/tools/quic/quic_epoll_connection_helper.h"
#include "net/tools/quic/quic_packet_writer_wrapper.h"
#include "net/tools/quic/quic_simple_server_session_helper.h"
#include "net/tools/quic/quic_time_wait_list_manager.h"
#include "net/tools/quic/test_tools/mock_quic_time_wait_list_manager.h"
#include "net/tools/quic/test_tools/quic_dispatcher_peer.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using base::StringPiece;
using net::EpollServer;
using net::test::ConstructEncryptedPacket;
using net::test::CryptoTestUtils;
using net::test::MockQuicConnection;
using net::test::MockQuicConnectionHelper;
using net::test::ValueRestore;
using std::string;
using std::vector;
using testing::_;
using testing::DoAll;
using testing::InSequence;
using testing::Invoke;
using testing::WithoutArgs;

namespace net {
namespace test {
    namespace {

        // Server session used by the dispatcher tests. Connection-close handling and
        // dynamic stream creation are gmock methods, and the crypto stream returned
        // by GetCryptoStream() can be swapped by a test through SetCryptoStream().
        class TestQuicSpdyServerSession : public QuicServerSessionBase {
        public:
            // Forwards its arguments to QuicServerSessionBase (passing nullptr for the
            // third and fourth base-class arguments) and seeds |crypto_stream_| with
            // whatever the base class reports at construction time.
            TestQuicSpdyServerSession(const QuicConfig& config,
                QuicConnection* connection,
                const QuicCryptoServerConfig* crypto_config,
                QuicCompressedCertsCache* compressed_certs_cache)
                : QuicServerSessionBase(config,
                    connection,
                    nullptr,
                    nullptr,
                    crypto_config,
                    compressed_certs_cache)
                , crypto_stream_(QuicServerSessionBase::GetCryptoStream())
            {
            }
            ~TestQuicSpdyServerSession() override { }

            // Mocked so that tests can observe (and react to) connection closure.
            MOCK_METHOD3(OnConnectionClosed,
                void(QuicErrorCode error,
                    const string& error_details,
                    ConnectionCloseSource source));
            // Stream factories are mocked; no test in this file sets actions on them.
            MOCK_METHOD1(CreateIncomingDynamicStream, QuicSpdyStream*(QuicStreamId id));
            MOCK_METHOD1(CreateOutgoingDynamicStream,
                QuicSpdyStream*(SpdyPriority priority));

            // Returns the stream most recently installed via the constructor or
            // SetCryptoStream().
            QuicCryptoServerStreamBase* GetCryptoStream() override
            {
                return crypto_stream_;
            }

            // Replaces the stream returned by GetCryptoStream(). The pointer is stored
            // raw and is never deleted by this class.
            void SetCryptoStream(QuicCryptoServerStream* crypto_stream)
            {
                crypto_stream_ = crypto_stream;
            }

            // Creates a real QuicCryptoServerStream bound to this session, honouring
            // the current value of FLAGS_enable_quic_stateless_reject_support.
            QuicCryptoServerStreamBase* CreateQuicCryptoServerStream(
                const QuicCryptoServerConfig* crypto_config,
                QuicCompressedCertsCache* compressed_certs_cache) override
            {
                return new QuicCryptoServerStream(
                    crypto_config, compressed_certs_cache,
                    FLAGS_enable_quic_stateless_reject_support, this);
            }

        private:
            // Not owned by this class.
            QuicCryptoServerStreamBase* crypto_stream_;

            DISALLOW_COPY_AND_ASSIGN(TestQuicSpdyServerSession);
        };

        // QuicDispatcher built on an epoll-based connection helper and alarm factory,
        // with session creation replaced by a gmock method so that each test decides
        // which session (if any) is returned for a new connection.
        class TestDispatcher : public QuicDispatcher {
        public:
            // Re-declared in the public section so tests can read the dispatcher's
            // current client/server addresses.
            using QuicDispatcher::current_client_address;
            using QuicDispatcher::current_server_address;

            // |eps| backs both the connection helper and the alarm factory that are
            // handed to the base class.
            TestDispatcher(const QuicConfig& config,
                const QuicCryptoServerConfig* crypto_config,
                EpollServer* eps)
                : QuicDispatcher(
                    config,
                    crypto_config,
                    QuicSupportedVersions(),
                    std::unique_ptr<QuicEpollConnectionHelper>(
                        new QuicEpollConnectionHelper(eps, QuicAllocator::BUFFER_POOL)),
                    std::unique_ptr<QuicServerSessionBase::Helper>(
                        new QuicSimpleServerSessionHelper(QuicRandom::GetInstance())),
                    std::unique_ptr<QuicEpollAlarmFactory>(
                        new QuicEpollAlarmFactory(eps)))
            {
            }

            // Mocked session factory; tests supply the session through WillOnce().
            MOCK_METHOD2(CreateQuicSession,
                QuicServerSessionBase*(QuicConnectionId connection_id,
                    const IPEndPoint& client_address));
        };

        // A Connection class which unregisters the session from the dispatcher when
        // sending connection close.
        // It'd be slightly more realistic to do this from the Session but it would
        // involve a lot more mocking.
        class MockServerConnection : public MockQuicConnection {
        public:
            MockServerConnection(QuicConnectionId connection_id,
                MockQuicConnectionHelper* helper,
                MockAlarmFactory* alarm_factory,
                QuicDispatcher* dispatcher)
                : MockQuicConnection(connection_id,
                    helper,
                    alarm_factory,
                    Perspective::IS_SERVER)
                , dispatcher_(dispatcher)
            {
            }

            void UnregisterOnConnectionClosed()
            {
                LOG(ERROR) << "Unregistering " << connection_id();
                dispatcher_->OnConnectionClosed(connection_id(), QUIC_NO_ERROR,
                    "Unregistering.");
            }

        private:
            QuicDispatcher* dispatcher_;
        };

        // Creates a TestQuicSpdyServerSession running over a new MockServerConnection.
        //
        // The new session is written to |*session| and also returned. The connection
        // is set up so that:
        //   - the session is its visitor;
        //   - CloseConnection() by default unregisters the connection from
        //     |dispatcher|;
        //   - exactly one ProcessUdpPacket() call from |client_address| is expected.
        //
        // Both objects are allocated with new and not freed here.
        // NOTE(review): presumably the dispatcher takes ownership of the returned
        // session once it is handed back from CreateQuicSession() - confirm.
        QuicServerSessionBase* CreateSession(
            QuicDispatcher* dispatcher,
            const QuicConfig& config,
            QuicConnectionId connection_id,
            const IPEndPoint& client_address,
            MockQuicConnectionHelper* helper,
            MockAlarmFactory* alarm_factory,
            const QuicCryptoServerConfig* crypto_config,
            QuicCompressedCertsCache* compressed_certs_cache,
            TestQuicSpdyServerSession** session)
        {
            MockServerConnection* connection = new MockServerConnection(
                connection_id, helper, alarm_factory, dispatcher);
            *session = new TestQuicSpdyServerSession(config, connection, crypto_config,
                compressed_certs_cache);
            connection->set_visitor(*session);
            ON_CALL(*connection, CloseConnection(_, _, _))
                .WillByDefault(WithoutArgs(Invoke(
                    connection, &MockServerConnection::UnregisterOnConnectionClosed)));
            // |connection| is the very object that was just given to the session, so
            // the expectation is set on it directly rather than by casting the
            // session's QuicConnection* back to the mock type.
            EXPECT_CALL(*connection, ProcessUdpPacket(_, client_address, _));

            return *session;
        }

        // Base fixture for the dispatcher tests. Owns a TestDispatcher together with
        // the helpers it needs, and offers ProcessPacket() overloads that build an
        // encrypted packet and feed it to the dispatcher.
        class QuicDispatcherTest : public ::testing::Test {
        public:
            QuicDispatcherTest()
                : helper_(&eps_, QuicAllocator::BUFFER_POOL)
                , alarm_factory_(&eps_)
                , crypto_config_(QuicCryptoServerConfig::TESTING,
                      QuicRandom::GetInstance(),
                      CryptoTestUtils::ProofSourceForTesting())
                , dispatcher_(config_, &crypto_config_, &eps_)
                , time_wait_list_manager_(nullptr)
                , session1_(nullptr)
                , session2_(nullptr)
            {
                dispatcher_.InitializeWithWriter(new QuicDefaultPacketWriter(1));
            }

            ~QuicDispatcherTest() override { }

            // Returns the connection of |session1_| as the mock type it was created
            // with. |session1_| must have been set (see CreateSession()).
            MockQuicConnection* connection1()
            {
                // static_cast is the correct cast for a downcast within a class
                // hierarchy; reinterpret_cast performs no pointer adjustment.
                return static_cast<MockQuicConnection*>(session1_->connection());
            }

            // Same as connection1(), for |session2_|.
            MockQuicConnection* connection2()
            {
                return static_cast<MockQuicConnection*>(session2_->connection());
            }

            // Process a packet with an 8 byte connection id,
            // 6 byte packet number, default path id, and packet number 1,
            // using the first supported version.
            void ProcessPacket(IPEndPoint client_address,
                QuicConnectionId connection_id,
                bool has_version_flag,
                bool has_multipath_flag,
                const string& data)
            {
                ProcessPacket(client_address, connection_id, has_version_flag,
                    has_multipath_flag, data, PACKET_8BYTE_CONNECTION_ID,
                    PACKET_6BYTE_PACKET_NUMBER);
            }

            // Process a packet with a default path id, and packet number 1,
            // using the first supported version.
            void ProcessPacket(IPEndPoint client_address,
                QuicConnectionId connection_id,
                bool has_version_flag,
                bool has_multipath_flag,
                const string& data,
                QuicConnectionIdLength connection_id_length,
                QuicPacketNumberLength packet_number_length)
            {
                ProcessPacket(client_address, connection_id, has_version_flag,
                    has_multipath_flag, data, connection_id_length,
                    packet_number_length, kDefaultPathId, 1);
            }

            // Process a packet using the first supported version.
            //
            // NOTE: |has_multipath_flag| and |path_id| are accepted but not forwarded.
            // The final overload passes the fixed values (false, false, 0) to
            // ConstructEncryptedPacket in their place - presumably multipath flag,
            // reset flag and path id; confirm against ConstructEncryptedPacket before
            // relying on non-default values here.
            void ProcessPacket(IPEndPoint client_address,
                QuicConnectionId connection_id,
                bool has_version_flag,
                bool has_multipath_flag,
                const string& data,
                QuicConnectionIdLength connection_id_length,
                QuicPacketNumberLength packet_number_length,
                QuicPathId path_id,
                QuicPacketNumber packet_number)
            {
                ProcessPacket(client_address, connection_id, has_version_flag,
                    QuicSupportedVersions().front(), data, connection_id_length,
                    packet_number_length, packet_number);
            }

            // Processes a packet: builds an encrypted packet for |version|, remembers
            // its bytes in |data_| (for ValidatePacket()) and hands it to the
            // dispatcher as received from |client_address| on |server_address_|.
            void ProcessPacket(IPEndPoint client_address,
                QuicConnectionId connection_id,
                bool has_version_flag,
                QuicVersion version,
                const string& data,
                QuicConnectionIdLength connection_id_length,
                QuicPacketNumberLength packet_number_length,
                QuicPacketNumber packet_number)
            {
                QuicVersionVector versions(SupportedVersions(version));
                std::unique_ptr<QuicEncryptedPacket> packet(ConstructEncryptedPacket(
                    connection_id, has_version_flag, false, false, 0, packet_number, data,
                    connection_id_length, packet_number_length, &versions));
                std::unique_ptr<QuicReceivedPacket> received_packet(
                    ConstructReceivedPacket(*packet, helper_.GetClock()->Now()));

                data_ = string(packet->data(), packet->length());
                dispatcher_.ProcessPacket(server_address_, client_address,
                    *received_packet);
            }

            // Checks that |packet| carries exactly the bytes of the packet most
            // recently built by ProcessPacket().
            void ValidatePacket(const QuicEncryptedPacket& packet)
            {
                EXPECT_EQ(data_.length(), packet.AsStringPiece().length());
                EXPECT_EQ(data_, packet.AsStringPiece());
            }

            // Installs a MockTimeWaitListManager into the dispatcher and keeps a raw
            // pointer to it in |time_wait_list_manager_| for setting expectations.
            void CreateTimeWaitListManager()
            {
                time_wait_list_manager_ = new MockTimeWaitListManager(QuicDispatcherPeer::GetWriter(&dispatcher_),
                    &dispatcher_, &helper_, &alarm_factory_);
                // dispatcher_ takes the ownership of time_wait_list_manager_.
                QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_,
                    time_wait_list_manager_);
            }

            // Returns the serialized form of a handshake message tagged kCHLO with no
            // further fields set.
            string SerializeCHLO()
            {
                CryptoHandshakeMessage client_hello;
                client_hello.set_tag(kCHLO);
                return client_hello.GetSerialized().AsStringPiece().as_string();
            }

            EpollServer eps_;
            QuicEpollConnectionHelper helper_;
            MockQuicConnectionHelper mock_helper_;
            QuicEpollAlarmFactory alarm_factory_;
            MockAlarmFactory mock_alarm_factory_;
            QuicConfig config_;
            QuicCryptoServerConfig crypto_config_;
            IPEndPoint server_address_;
            TestDispatcher dispatcher_;
            // Set by CreateTimeWaitListManager(); owned by |dispatcher_|.
            MockTimeWaitListManager* time_wait_list_manager_;
            // Filled in by CreateSession(); null until then.
            TestQuicSpdyServerSession* session1_;
            TestQuicSpdyServerSession* session2_;
            // Bytes of the packet most recently built by ProcessPacket().
            string data_;
        };

        // Two CHLOs with different connection ids each create a session; a later data
        // packet for the first connection id must reach the first connection with
        // the exact bytes that were sent.
        TEST_F(QuicDispatcherTest, ProcessPackets)
        {
            const IPEndPoint client_address(net::test::Loopback4(), 1);
            server_address_ = IPEndPoint(net::test::Any4(), 5);

            // Connection id 1.
            QuicServerSessionBase* const first_session = CreateSession(
                &dispatcher_, config_, 1, client_address, &mock_helper_,
                &mock_alarm_factory_, &crypto_config_,
                QuicDispatcherPeer::GetCache(&dispatcher_), &session1_);
            EXPECT_CALL(dispatcher_, CreateQuicSession(1, client_address))
                .WillOnce(testing::Return(first_session));
            ProcessPacket(client_address, 1, true, false, SerializeCHLO());
            EXPECT_EQ(client_address, dispatcher_.current_client_address());
            EXPECT_EQ(server_address_, dispatcher_.current_server_address());

            // Connection id 2.
            QuicServerSessionBase* const second_session = CreateSession(
                &dispatcher_, config_, 2, client_address, &mock_helper_,
                &mock_alarm_factory_, &crypto_config_,
                QuicDispatcherPeer::GetCache(&dispatcher_), &session2_);
            EXPECT_CALL(dispatcher_, CreateQuicSession(2, client_address))
                .WillOnce(testing::Return(second_session));
            ProcessPacket(client_address, 2, true, false, SerializeCHLO());

            // A follow-up packet for connection id 1 goes to the first connection.
            EXPECT_CALL(*connection1(), ProcessUdpPacket(_, _, _))
                .Times(1)
                .WillOnce(testing::WithArgs<2>(
                    Invoke(this, &QuicDispatcherTest::ValidatePacket)));
            ProcessPacket(client_address, 1, false, false, "data");
        }

        // A CHLO carrying a version one below the minimum supported version must not
        // cause a session to be created.
        TEST_F(QuicDispatcherTest, StatelessVersionNegotiation)
        {
            const IPEndPoint client_address(net::test::Loopback4(), 1);
            server_address_ = IPEndPoint(net::test::Any4(), 5);

            EXPECT_CALL(dispatcher_, CreateQuicSession(1, client_address)).Times(0);

            const QuicVersion unsupported_version = static_cast<QuicVersion>(QuicVersionMin() - 1);
            ProcessPacket(client_address, 1, true, unsupported_version,
                SerializeCHLO(), PACKET_8BYTE_CONNECTION_ID,
                PACKET_6BYTE_PACKET_NUMBER, 1);
        }

        // Shutting the dispatcher down must close an established connection with
        // QUIC_PEER_GOING_AWAY.
        TEST_F(QuicDispatcherTest, Shutdown)
        {
            const IPEndPoint client_address(net::test::Loopback4(), 1);

            QuicServerSessionBase* const session = CreateSession(
                &dispatcher_, config_, 1, client_address, &mock_helper_,
                &mock_alarm_factory_, &crypto_config_,
                QuicDispatcherPeer::GetCache(&dispatcher_), &session1_);
            EXPECT_CALL(dispatcher_, CreateQuicSession(_, client_address))
                .WillOnce(testing::Return(session));

            ProcessPacket(client_address, 1, true, false, SerializeCHLO());

            EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _));

            dispatcher_.Shutdown();
        }

        // Once a connection has been closed by a public reset, its connection id must
        // be in the time wait list and later packets for that id must be handed to
        // the time wait list manager rather than to a session.
        TEST_F(QuicDispatcherTest, TimeWaitListManager)
        {
            CreateTimeWaitListManager();

            // Create a new session.
            IPEndPoint client_address(net::test::Loopback4(), 1);
            QuicConnectionId connection_id = 1;
            EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, client_address))
                .WillOnce(testing::Return(CreateSession(
                    &dispatcher_, config_, connection_id, client_address, &mock_helper_,
                    &mock_alarm_factory_, &crypto_config_,
                    QuicDispatcherPeer::GetCache(&dispatcher_), &session1_)));
            ProcessPacket(client_address, connection_id, true, false, SerializeCHLO());

            // Close the connection by sending public reset packet.
            QuicPublicResetPacket packet;
            packet.public_header.connection_id = connection_id;
            packet.public_header.reset_flag = true;
            packet.public_header.version_flag = false;
            packet.rejected_packet_number = 19191;
            packet.nonce_proof = 132232;
            std::unique_ptr<QuicEncryptedPacket> encrypted(
                QuicFramer::BuildPublicResetPacket(packet));
            std::unique_ptr<QuicReceivedPacket> received(
                ConstructReceivedPacket(*encrypted, helper_.GetClock()->Now()));
            // When the session is told about the reset, unregister the connection
            // from the dispatcher (the same action CreateSession() installs as the
            // default for CloseConnection()).
            EXPECT_CALL(*session1_, OnConnectionClosed(QUIC_PUBLIC_RESET, _, ConnectionCloseSource::FROM_PEER))
                .Times(1)
                .WillOnce(WithoutArgs(Invoke(
                    reinterpret_cast<MockServerConnection*>(session1_->connection()),
                    &MockServerConnection::UnregisterOnConnectionClosed)));
            // Route the reset packet through ReallyProcessUdpPacket instead of leaving
            // ProcessUdpPacket as a mock with no action.
            EXPECT_CALL(*reinterpret_cast<MockQuicConnection*>(session1_->connection()),
                ProcessUdpPacket(_, _, _))
                .WillOnce(
                    Invoke(reinterpret_cast<MockQuicConnection*>(session1_->connection()),
                        &MockQuicConnection::ReallyProcessUdpPacket));
            dispatcher_.ProcessPacket(IPEndPoint(), client_address, *received);
            EXPECT_TRUE(time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id));

            // Dispatcher forwards subsequent packets for this connection_id to the time
            // wait list manager.
            EXPECT_CALL(*time_wait_list_manager_,
                ProcessPacket(_, _, connection_id, _, _))
                .Times(1);
            // The id is already in the list, so it must not be added a second time.
            EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
                .Times(0);
            ProcessPacket(client_address, connection_id, true, false, "data");
        }

        // A first packet without the version flag must not create a session; its
        // connection id is added to the time wait list and the packet is processed
        // by the time wait list manager.
        TEST_F(QuicDispatcherTest, NoVersionPacketToTimeWaitListManager)
        {
            CreateTimeWaitListManager();

            const IPEndPoint client_address(net::test::Loopback4(), 1);
            const QuicConnectionId connection_id = 1;

            // No session for this connection id ...
            EXPECT_CALL(dispatcher_, CreateQuicSession(_, _)).Times(0);
            // ... instead the time wait list manager sees the packet exactly once
            // and the id is added to the list exactly once.
            EXPECT_CALL(*time_wait_list_manager_,
                ProcessPacket(_, _, connection_id, _, _))
                .Times(1);
            EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
                .Times(1);

            ProcessPacket(client_address, connection_id, false, false, SerializeCHLO());
        }

        // A packet whose client address has port 0 must be dropped: no session is
        // created and the time wait list manager is not involved.
        TEST_F(QuicDispatcherTest, ProcessPacketWithZeroPort)
        {
            CreateTimeWaitListManager();

            const IPEndPoint zero_port_client(net::test::Loopback4(), 0);
            server_address_ = IPEndPoint(net::test::Any4(), 5);

            EXPECT_CALL(dispatcher_, CreateQuicSession(1, zero_port_client)).Times(0);
            EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _)).Times(0);
            EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
                .Times(0);

            ProcessPacket(zero_port_client, 1, true, false, SerializeCHLO());
        }

        // A CHLO whose packet number equals kMaxReasonableInitialPacketNumber is still
        // allowed to start a connection.
        TEST_F(QuicDispatcherTest, OKSeqNoPacketProcessed)
        {
            const IPEndPoint client_address(net::test::Loopback4(), 1);
            const QuicConnectionId connection_id = 1;
            server_address_ = IPEndPoint(net::test::Any4(), 5);

            QuicServerSessionBase* const session = CreateSession(
                &dispatcher_, config_, connection_id, client_address, &mock_helper_,
                &mock_alarm_factory_, &crypto_config_,
                QuicDispatcherPeer::GetCache(&dispatcher_), &session1_);
            EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, client_address))
                .WillOnce(testing::Return(session));

            // Largest packet number that may open a connection.
            ProcessPacket(client_address, connection_id, true, false, SerializeCHLO(),
                PACKET_8BYTE_CONNECTION_ID, PACKET_6BYTE_PACKET_NUMBER,
                kDefaultPathId,
                QuicDispatcher::kMaxReasonableInitialPacketNumber);

            EXPECT_EQ(client_address, dispatcher_.current_client_address());
            EXPECT_EQ(server_address_, dispatcher_.current_server_address());
        }

        // A CHLO whose packet number is one above kMaxReasonableInitialPacketNumber
        // must not create a session; it is handed to the time wait list manager.
        TEST_F(QuicDispatcherTest, TooBigSeqNoPacketToTimeWaitListManager)
        {
            CreateTimeWaitListManager();

            const IPEndPoint client_address(net::test::Loopback4(), 1);
            const QuicConnectionId connection_id = 1;

            EXPECT_CALL(dispatcher_, CreateQuicSession(_, _)).Times(0);
            // The id is added to the time wait list once and the packet is processed
            // there once.
            EXPECT_CALL(*time_wait_list_manager_,
                ProcessPacket(_, _, connection_id, _, _))
                .Times(1);
            EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
                .Times(1);

            // One past the largest packet number that may open a connection.
            ProcessPacket(client_address, connection_id, true, false, SerializeCHLO(),
                PACKET_8BYTE_CONNECTION_ID, PACKET_6BYTE_PACKET_NUMBER,
                kDefaultPathId,
                QuicDispatcher::kMaxReasonableInitialPacketNumber + 1);
        }

        // Crypto server stream whose handshake-confirmed state can be set directly by
        // a test, which is what the stateless-reject tests need.
        class MockQuicCryptoServerStream : public QuicCryptoServerStream {
        public:
            // Takes |crypto_config| by reference and passes its address to the base
            // class, together with the current value of
            // FLAGS_enable_quic_stateless_reject_support.
            MockQuicCryptoServerStream(const QuicCryptoServerConfig& crypto_config,
                QuicCompressedCertsCache* compressed_certs_cache,
                QuicServerSessionBase* session)
                : QuicCryptoServerStream(&crypto_config,
                    compressed_certs_cache,
                    FLAGS_enable_quic_stateless_reject_support,
                    session)
            {
            }

            // Overwrites the inherited |handshake_confirmed_| state.
            void set_handshake_confirmed_for_testing(bool handshake_confirmed)
            {
                handshake_confirmed_ = handshake_confirmed;
            }

        private:
            DISALLOW_COPY_AND_ASSIGN(MockQuicCryptoServerStream);
        };

        // One combination of settings for the stateless-reject tests.
        struct StatelessRejectTestParams {
            StatelessRejectTestParams(bool enable_stateless_rejects_via_flag,
                bool client_supports_statelesss_rejects,
                bool crypto_handshake_successful)
                : enable_stateless_rejects_via_flag(enable_stateless_rejects_via_flag)
                , client_supports_statelesss_rejects(client_supports_statelesss_rejects)
                , crypto_handshake_successful(crypto_handshake_successful)
            {
            }

            // Writes the three settings to |os|, one per line, and returns |os|.
            // Uses '\n' rather than std::endl: the text produced is the same, but the
            // stream is not flushed after every line.
            friend std::ostream& operator<<(std::ostream& os,
                const StatelessRejectTestParams& p)
            {
                os << "{  enable_stateless_rejects_via_flag: "
                   << p.enable_stateless_rejects_via_flag << '\n';
                os << " client_supports_statelesss_rejects: "
                   << p.client_supports_statelesss_rejects << '\n';
                os << "  crypto_handshake_successful: " << p.crypto_handshake_successful
                   << " }";
                return os;
            }

            // This only enables the stateless reject feature via the feature-flag.
            // This should be a no-op if the peer does not support them.
            bool enable_stateless_rejects_via_flag;
            // Whether or not the client supports stateless rejects.
            bool client_supports_statelesss_rejects;
            // Should the initial crypto handshake succeed or not.
            bool crypto_handshake_successful;
        };

        // Returns all eight combinations of the three stateless-reject settings. The
        // flag setting varies slowest and the handshake outcome fastest, each going
        // through true before false.
        vector<StatelessRejectTestParams> GetStatelessRejectTestParams()
        {
            const bool kChoices[] = { true, false };

            vector<StatelessRejectTestParams> all_combinations;
            all_combinations.reserve(8);
            for (const bool via_flag : kChoices) {
                for (const bool client_supports : kChoices) {
                    for (const bool handshake_ok : kChoices) {
                        all_combinations.emplace_back(via_flag, client_supports,
                            handshake_ok);
                    }
                }
            }
            return all_combinations;
        }

        // Fixture for the stateless-reject tests, parameterized on
        // StatelessRejectTestParams. It overrides the global
        // FLAGS_enable_quic_stateless_reject_support for the duration of each test and
        // puts the original value back when the fixture is destroyed, so the override
        // does not leak into tests that run afterwards.
        class QuicDispatcherStatelessRejectTest
            : public QuicDispatcherTest,
              public ::testing::WithParamInterface<StatelessRejectTestParams> {
        public:
            QuicDispatcherStatelessRejectTest()
                : crypto_stream1_(nullptr)
                , saved_stateless_reject_flag_(FLAGS_enable_quic_stateless_reject_support)
            {
            }

            ~QuicDispatcherStatelessRejectTest() override
            {
                // delete on a null pointer is a no-op, so no guard is needed.
                delete crypto_stream1_;
                // Undo the override applied in SetUp().
                FLAGS_enable_quic_stateless_reject_support = saved_stateless_reject_flag_;
            }

            // This test setup assumes that all testing will be done using
            // crypto_stream1_.
            void SetUp() override
            {
                FLAGS_enable_quic_stateless_reject_support = GetParam().enable_stateless_rejects_via_flag;
            }

            // Returns true if, for the current test parameters, the server is
            // expected to emit a stateless reject: the feature flag is on, the
            // handshake did not succeed, and the client supports stateless rejects.
            bool ExpectStatelessReject()
            {
                return GetParam().enable_stateless_rejects_via_flag && !GetParam().crypto_handshake_successful && GetParam().client_supports_statelesss_rejects;
            }

            // Sets up session1_ and crypto_stream1_ based on the test parameters and
            // returns session1_. The new crypto stream is installed on the session
            // and remains owned by this fixture (it is deleted in the destructor).
            QuicServerSessionBase* CreateSessionBasedOnTestParams(
                QuicConnectionId connection_id,
                const IPEndPoint& client_address)
            {
                CreateSession(&dispatcher_, config_, connection_id, client_address,
                    &mock_helper_, &mock_alarm_factory_, &crypto_config_,
                    QuicDispatcherPeer::GetCache(&dispatcher_), &session1_);

                crypto_stream1_ = new MockQuicCryptoServerStream(
                    crypto_config_, QuicDispatcherPeer::GetCache(&dispatcher_), session1_);
                session1_->SetCryptoStream(crypto_stream1_);
                crypto_stream1_->set_handshake_confirmed_for_testing(
                    GetParam().crypto_handshake_successful);
                crypto_stream1_->SetPeerSupportsStatelessRejects(
                    GetParam().client_supports_statelesss_rejects);
                return session1_;
            }

            // Null until CreateSessionBasedOnTestParams() has been called.
            MockQuicCryptoServerStream* crypto_stream1_;

        private:
            // Value of FLAGS_enable_quic_stateless_reject_support at construction.
            const bool saved_stateless_reject_flag_;
        };

        // Instantiates the QuicDispatcherStatelessRejectTest suite once for every
        // entry returned by GetStatelessRejectTestParams(), i.e. for each combination
        // of feature flag on/off, client support on/off and handshake
        // success/failure.
        INSTANTIATE_TEST_CASE_P(QuicDispatcherStatelessRejectTests,
            QuicDispatcherStatelessRejectTest,
            ::testing::ValuesIn(GetStatelessRejectTestParams()));

        // After the first packet creates a session, a second packet must go to the
        // time wait list manager when a stateless reject is expected, and to the
        // existing connection otherwise.
        TEST_P(QuicDispatcherStatelessRejectTest, ParameterizedBasicTest)
        {
            CreateTimeWaitListManager();

            const IPEndPoint client_address(net::test::Loopback4(), 1);
            const QuicConnectionId connection_id = 1;
            // Depends only on the test parameters, so it is evaluated once.
            const bool expect_reject = ExpectStatelessReject();

            QuicServerSessionBase* const session = CreateSessionBasedOnTestParams(connection_id, client_address);
            EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, client_address))
                .WillOnce(testing::Return(session));

            // First packet for the connection.
            ProcessPacket(client_address, connection_id, true, false, SerializeCHLO());
            if (expect_reject) {
                // Stand in for the crypto stream, which closes the connection when it
                // issues a stateless reject.
                session1_->connection()->CloseConnection(
                    QUIC_CRYPTO_HANDSHAKE_STATELESS_REJECT, "stateless reject",
                    ConnectionCloseBehavior::SILENT_CLOSE);
            }

            // The connection id is on the time wait list exactly when a stateless
            // reject was expected.
            EXPECT_EQ(expect_reject,
                time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id));

            if (!expect_reject) {
                // The second packet reaches the connection and must carry the bytes
                // that were sent.
                EXPECT_CALL(*connection1(), ProcessUdpPacket(_, _, _))
                    .Times(1)
                    .WillOnce(testing::WithArgs<2>(
                        Invoke(this, &QuicDispatcherTest::ValidatePacket)));
            } else {
                // The second packet is handled by the time wait list manager.
                EXPECT_CALL(*time_wait_list_manager_,
                    ProcessPacket(_, _, connection_id, _, _))
                    .Times(1);
            }
            ProcessPacket(client_address, connection_id, true, false, "data");
        }

        // With cheap stateless rejects switched on: if stateless rejects are enabled
        // via the flag, the CHLO below must not create a session and its connection
        // id must end up on the time wait list; otherwise a session is created as
        // usual.
        TEST_P(QuicDispatcherStatelessRejectTest, CheapRejects)
        {
            // Scoped override: the previous flag value is put back when this test body
            // exits, so the setting does not leak into other tests.
            ValueRestore<bool> cheap_rejects_override(
                &FLAGS_quic_use_cheap_stateless_rejects, true);
            CreateTimeWaitListManager();

            IPEndPoint client_address(net::test::Loopback4(), 1);
            QuicConnectionId connection_id = 1;
            if (GetParam().enable_stateless_rejects_via_flag) {
                EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, client_address))
                    .Times(0);
            } else {
                EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, client_address))
                    .WillOnce(testing::Return(
                        CreateSessionBasedOnTestParams(connection_id, client_address)));
            }

            VLOG(1) << "ExpectStatelessReject: " << ExpectStatelessReject();
            VLOG(1) << "Params: " << GetParam();
            // Process the first packet for the connection.
            // clang-format off
  CryptoHandshakeMessage client_hello = CryptoTestUtils::Message(
      "CHLO",
      "AEAD", "AESG",
      "KEXS", "C255",
      "COPT", "SREJ",
      "NONC", "1234567890123456789012",
      "VER\0", "Q025",
      "$padding", static_cast<int>(kClientHelloMinimumSize),
      nullptr);
            // clang-format on

            ProcessPacket(client_address, connection_id, true, false,
                client_hello.GetSerialized().AsStringPiece().as_string());

            if (GetParam().enable_stateless_rejects_via_flag) {
                EXPECT_TRUE(
                    time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id));
            }
        }

        // Fixture for the stopgap check that packets with truncated connection IDs
        // are dropped. It adds nothing on top of QuicDispatcherTest.
        class QuicDispatcherTestStrayPacketConnectionId : public QuicDispatcherTest {};

        // A packet sent with a 0-byte (truncated) connection ID must be dropped: no
        // session is created and the time wait list manager is left untouched for
        // that connection id.
        TEST_F(QuicDispatcherTestStrayPacketConnectionId,
            StrayPacketTruncatedConnectionId)
        {
            CreateTimeWaitListManager();

            const IPEndPoint client_address(net::test::Loopback4(), 1);
            const QuicConnectionId connection_id = 1;

            EXPECT_CALL(dispatcher_, CreateQuicSession(_, _)).Times(0);
            EXPECT_CALL(*time_wait_list_manager_,
                ProcessPacket(_, _, connection_id, _, _))
                .Times(0);
            EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
                .Times(0);

            ProcessPacket(client_address, connection_id, true, false, "data",
                PACKET_0BYTE_CONNECTION_ID, PACKET_6BYTE_PACKET_NUMBER);
        }

        // QuicPacketWriterWrapper whose "blocked" state is a plain flag that tests
        // flip directly. It never actually writes a packet.
        class BlockingWriter : public QuicPacketWriterWrapper {
        public:
            BlockingWriter() {}

            // Reports the current value of the test-controlled flag.
            bool IsWriteBlocked() const override
            {
                return write_blocked_;
            }

            // Clears the blocked flag.
            void SetWritable() override
            {
                write_blocked_ = false;
            }

            // Intentionally unimplemented: logs at DFATAL and returns a
            // default-constructed WriteResult.
            WriteResult WritePacket(const char* buffer,
                size_t buf_len,
                const IPAddress& self_client_address,
                const IPEndPoint& peer_client_address,
                PerPacketOptions* options) override
            {
                // A real implementation honoring the fake blocked status would be
                // possible, but it would be significantly more work in Chromium,
                // and since this method is not called anyway, it is not worth it.
                LOG(DFATAL) << "Not supported";
                return WriteResult();
            }

            // Public so that test fixtures can set it directly; starts unblocked.
            bool write_blocked_ = false;
        };

        class QuicDispatcherWriteBlockedListTest : public QuicDispatcherTest {
        public:
            void SetUp() override
            {
                writer_ = new BlockingWriter;
                QuicDispatcherPeer::UseWriter(&dispatcher_, writer_);

                IPEndPoint client_address(net::test::Loopback4(), 1);

                EXPECT_CALL(dispatcher_, CreateQuicSession(_, client_address))
                    .WillOnce(testing::Return(CreateSession(
                        &dispatcher_, config_, 1, client_address, &helper_, &alarm_factory_,
                        &crypto_config_, QuicDispatcherPeer::GetCache(&dispatcher_),
                        &session1_)));
                ProcessPacket(client_address, 1, true, false, SerializeCHLO());

                EXPECT_CALL(dispatcher_, CreateQuicSession(_, client_address))
                    .WillOnce(testing::Return(CreateSession(
                        &dispatcher_, config_, 2, client_address, &helper_, &alarm_factory_,
                        &crypto_config_, QuicDispatcherPeer::GetCache(&dispatcher_),
                        &session2_)));
                ProcessPacket(client_address, 2, true, false, SerializeCHLO());

                blocked_list_ = QuicDispatcherPeer::GetWriteBlockedList(&dispatcher_);
            }

            void TearDown() override
            {
                EXPECT_CALL(*connection1(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _));
                EXPECT_CALL(*connection2(), CloseConnection(QUIC_PEER_GOING_AWAY, _, _));
                dispatcher_.Shutdown();
            }

            void SetBlocked() { writer_->write_blocked_ = true; }

            void BlockConnection2()
            {
                writer_->write_blocked_ = true;
                dispatcher_.OnWriteBlocked(connection2());
            }

        protected:
            MockQuicConnectionHelper helper_;
            MockAlarmFactory alarm_factory_;
            BlockingWriter* writer_;
            QuicDispatcher::WriteBlockedList* blocked_list_;
        };

        // A connection registered through OnWriteBlocked() receives exactly one
        // OnCanWrite() notification, after which nothing is left pending.
        TEST_F(QuicDispatcherWriteBlockedListTest, BasicOnCanWrite)
        {
            // Nothing is registered yet, so there is nobody to notify.
            dispatcher_.OnCanWrite();

            // Block the writer, register connection 1, and expect it to be
            // notified when the dispatcher is told it can write.
            SetBlocked();
            auto* const first = connection1();
            dispatcher_.OnWriteBlocked(first);
            EXPECT_CALL(*first, OnCanWrite());
            dispatcher_.OnCanWrite();

            // A further OnCanWrite() must not notify it a second time.
            EXPECT_CALL(*first, OnCanWrite()).Times(0);
            dispatcher_.OnCanWrite();
            EXPECT_FALSE(dispatcher_.HasPendingWrites());
        }

        // Connections are notified in the order in which they were registered as
        // write-blocked.
        TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteOrder)
        {
            // Every expectation below must be matched in the order it is declared.
            InSequence in_order;

            // Register 1 then 2; expect notifications in that order.
            SetBlocked();
            auto* const first = connection1();
            dispatcher_.OnWriteBlocked(first);
            auto* const second = connection2();
            dispatcher_.OnWriteBlocked(second);
            EXPECT_CALL(*first, OnCanWrite());
            EXPECT_CALL(*second, OnCanWrite());
            dispatcher_.OnCanWrite();

            // Register 2 then 1; expect the reverse order.
            SetBlocked();
            dispatcher_.OnWriteBlocked(second);
            dispatcher_.OnWriteBlocked(first);
            EXPECT_CALL(*second, OnCanWrite());
            EXPECT_CALL(*first, OnCanWrite());
            dispatcher_.OnCanWrite();
        }

        // Erasing a connection from the blocked list suppresses its notification,
        // leaves other entries alone, and does not prevent it being added again.
        TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteRemove)
        {
            // Add and remove one connection: it must not be notified.
            SetBlocked();
            auto* const first = connection1();
            dispatcher_.OnWriteBlocked(first);
            blocked_list_->erase(first);
            EXPECT_CALL(*first, OnCanWrite()).Times(0);
            dispatcher_.OnCanWrite();

            // Add two connections and remove one: the other is still notified.
            SetBlocked();
            dispatcher_.OnWriteBlocked(first);
            auto* const second = connection2();
            dispatcher_.OnWriteBlocked(second);
            blocked_list_->erase(first);
            EXPECT_CALL(*second, OnCanWrite());
            dispatcher_.OnCanWrite();

            // Add, remove, and add again: the connection is notified exactly once.
            SetBlocked();
            dispatcher_.OnWriteBlocked(first);
            blocked_list_->erase(first);
            dispatcher_.OnWriteBlocked(first);
            EXPECT_CALL(*first, OnCanWrite());
            dispatcher_.OnCanWrite();
        }

        // Registering the same connection twice behaves like registering it once.
        TEST_F(QuicDispatcherWriteBlockedListTest, DoubleAdd)
        {
            // Two adds followed by a single erase: the connection is gone from the
            // list and must not be notified.
            SetBlocked();
            auto* const first = connection1();
            dispatcher_.OnWriteBlocked(first);
            dispatcher_.OnWriteBlocked(first);
            blocked_list_->erase(first);
            EXPECT_CALL(*first, OnCanWrite()).Times(0);
            dispatcher_.OnCanWrite();

            // Two adds and no erase: the connection is notified exactly once.
            SetBlocked();
            dispatcher_.OnWriteBlocked(first);
            dispatcher_.OnWriteBlocked(first);
            EXPECT_CALL(*first, OnCanWrite());
            dispatcher_.OnCanWrite();
        }

        // If the writer becomes blocked while a connection is being notified, the
        // dispatcher must stop notifying the remaining connections and pick them
        // up on the next OnCanWrite().
        TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteHandleBlock)
        {
            // Finally make sure if we write block on a write call, we stop calling.
            // InSequence: expectations must be matched in declaration order.
            InSequence s;
            SetBlocked();
            dispatcher_.OnWriteBlocked(connection1());
            dispatcher_.OnWriteBlocked(connection2());
            // connection1's notification runs SetBlocked(), which sets the writer's
            // write_blocked_ flag again in the middle of the dispatch.
            EXPECT_CALL(*connection1(), OnCanWrite())
                .WillOnce(Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked));
            // ...so connection2 must not be notified during this pass.
            EXPECT_CALL(*connection2(), OnCanWrite()).Times(0);
            dispatcher_.OnCanWrite();

            // And we'll resume where we left off when we get another call.
            EXPECT_CALL(*connection2(), OnCanWrite());
            dispatcher_.OnCanWrite();
        }

        // A connection that re-registers as blocked while it is being notified is
        // not notified again in the same OnCanWrite() pass; it stays pending until
        // the next pass.
        TEST_F(QuicDispatcherWriteBlockedListTest, LimitedWrites)
        {
            // Make sure both connections are notified.  connection2's notification
            // runs BlockConnection2(), which sets the writer's blocked flag and
            // registers connection2 as blocked again; it should not be notified
            // again immediately due to limits.
            // InSequence: expectations must be matched in declaration order.
            InSequence s;
            SetBlocked();
            dispatcher_.OnWriteBlocked(connection1());
            dispatcher_.OnWriteBlocked(connection2());
            EXPECT_CALL(*connection1(), OnCanWrite());
            EXPECT_CALL(*connection2(), OnCanWrite())
                .WillOnce(
                    Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection2));
            dispatcher_.OnCanWrite();
            // connection2's re-registration must still be pending.
            EXPECT_TRUE(dispatcher_.HasPendingWrites());

            // Now call OnCanWrite again, and connection2 should get its second
            // chance, after which nothing is left pending.
            EXPECT_CALL(*connection2(), OnCanWrite());
            dispatcher_.OnCanWrite();
            EXPECT_FALSE(dispatcher_.HasPendingWrites());
        }

        // Same scenario as OnCanWriteHandleBlock, additionally checking that
        // HasPendingWrites() reflects the connection that was skipped when the
        // writer became blocked mid-dispatch.
        TEST_F(QuicDispatcherWriteBlockedListTest, TestWriteLimits)
        {
            // If the writer becomes blocked during a notification, the dispatcher
            // must stop notifying further connections.
            // InSequence: expectations must be matched in declaration order.
            InSequence s;
            SetBlocked();
            dispatcher_.OnWriteBlocked(connection1());
            dispatcher_.OnWriteBlocked(connection2());
            // connection1's notification runs SetBlocked(), which sets the writer's
            // write_blocked_ flag again in the middle of the dispatch.
            EXPECT_CALL(*connection1(), OnCanWrite())
                .WillOnce(Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked));
            EXPECT_CALL(*connection2(), OnCanWrite()).Times(0);
            dispatcher_.OnCanWrite();
            // connection2 was not notified, so a write is still pending.
            EXPECT_TRUE(dispatcher_.HasPendingWrites());

            // And we'll resume where we left off when we get another call.
            EXPECT_CALL(*connection2(), OnCanWrite());
            dispatcher_.OnCanWrite();
            EXPECT_FALSE(dispatcher_.HasPendingWrites());
        }

    } // namespace
} // namespace test
} // namespace net
